import torch
import random
import torch.nn.functional as F

class DiffAug:
    """Policy-driven batch augmentations built from torch tensor ops.

    The policy is a comma-separated string of augmentation names; each name
    maps to one or more methods in ``AUGMENT_FNS`` that are applied in order.
    Names that are not in ``AUGMENT_FNS`` are silently ignored.

    Every augmentation method operates on a 4-D ``(B, C, H, W)`` tensor.
    ``DiffAugment`` converts ``(B, H, W, C)`` input to that layout first when
    ``config['channels_first']`` is false.
    """

    def __init__(self, config, policy, crop_size=4, angle=1, noise_ratio=0.1,
                 translation_ratio=0.125, cutout_ratio=0.3):
        """Store the augmentation settings.

        Args:
            config: Mapping queried with ``.get`` for ``'channels_first'``
                (default True), ``'flip_horizontal'`` (default True) and
                ``'flip_vertical'`` (default False). ``None`` is treated as
                an empty mapping.
            policy: Comma-separated augmentation names, e.g.
                ``'color,translation,cutout'``.
            crop_size: Number of pixels removed from each spatial extent by
                ``rand_crop``.
            angle: Rotation magnitude for ``rand_rot``. It is passed straight
                to cos/sin, so it is interpreted as radians.
            noise_ratio: Scale of the Gaussian noise added by ``random_noise``.
            translation_ratio: Maximum shift of ``rand_translation`` as a
                fraction of the image size.
            cutout_ratio: Side of the ``rand_cutout`` patch as a fraction of
                the image size.
        """
        self.config = config if config is not None else {}
        self.AUGMENT_FNS = {
            'color': [self.rand_brightness, self.rand_saturation, self.rand_contrast],
            'translation': [self.rand_translation],
            'cutout': [self.rand_cutout],
            'flip': [self.rand_flip],
            'rotate': [self.rand_rot],
            'crop': [self.rand_crop],
            'noise': [self.random_noise],
        }
        self.policy = policy
        self.crop_size = crop_size
        self.angle = angle
        self.noise_ratio = noise_ratio
        self.translation_ratio = translation_ratio
        self.cutout_ratio = cutout_ratio

    def DiffAugment(self, x, cut=False):
        """Apply the configured policy to a batch.

        Args:
            x: 4-D image batch, ``(B, C, H, W)`` or ``(B, H, W, C)`` depending
                on ``config['channels_first']``.
            cut: When True, cutout is applied first and then removed from the
                policy so it is not applied twice.

        Returns:
            A contiguous tensor in the same layout as the input.
        """
        names = [name.strip() for name in self.policy.split(',')]
        if cut:
            names = [name for name in names if name != 'cutout']
        elif 'cutout' in names and 'crop' in names:
            # Both occlusion-style augmentations requested: keep only one,
            # chosen with equal probability.
            dropped = 'cutout' if random.random() < 0.5 else 'crop'
            names = [name for name in names if name != dropped]

        channels_first = self.config.get('channels_first', True)
        if not channels_first:
            x = x.permute(0, 3, 1, 2)

        # The forced cutout must run after the layout conversion: rand_cutout
        # reads dims 2 and 3 as the spatial axes.
        if cut:
            x = self.rand_cutout(x)

        for name in names:
            for augment in self.AUGMENT_FNS.get(name, []):
                x = augment(x)

        if not channels_first:
            x = x.permute(0, 2, 3, 1)
        return x.contiguous()

    def random_noise(self, x):
        """Return ``x`` plus standard-normal noise scaled by ``noise_ratio``."""
        return x + torch.randn_like(x) * self.noise_ratio

    def rand_rot(self, image):
        """Rotate the whole batch by ``+angle`` or ``-angle`` (one draw per call).

        The rotation is applied through ``affine_grid``/``grid_sample``, so
        the output has the same shape as ``image``.

        NOTE(review): ``angle`` is fed to cos/sin directly, i.e. radians; the
        default of 1 is therefore about 57 degrees. Confirm this is intended.
        """
        signed_angle = self.angle if random.randint(0, 1) else -self.angle
        angle = torch.tensor(float(signed_angle), dtype=torch.float32, device=image.device)

        cos_a, sin_a = torch.cos(angle), torch.sin(angle)
        zero = torch.zeros_like(angle)
        # 2x3 affine matrix [[cos, -sin, 0], [sin, cos, 0]], one copy per sample.
        theta = torch.stack([
            torch.stack([cos_a, -sin_a, zero]),
            torch.stack([sin_a, cos_a, zero]),
        ]).unsqueeze(0).repeat(image.shape[0], 1, 1)

        grid = F.affine_grid(theta, image.size(), align_corners=True)
        return F.grid_sample(image, grid, align_corners=True)

    def rand_crop(self, tensor):
        """Keep a random window and zero everything outside it.

        The window is ``(H - crop_size) x (W - crop_size)``; it stays at its
        original position and the remainder is zero-padded, so the output has
        the same shape as the input. One window is drawn per call and shared
        by the whole batch.

        Raises:
            ValueError: If ``crop_size`` is negative or exceeds H or W.
        """
        _, _, height, width = tensor.shape
        margin = self.crop_size
        if margin < 0 or margin > min(height, width):
            raise ValueError(
                f"crop_size must be in [0, {min(height, width)}] for input of "
                f"size {height}x{width}, got {margin}"
            )

        # Separate extents per axis so non-square inputs are handled.
        keep_h = height - margin
        keep_w = width - margin

        top = random.randint(0, height - keep_h)
        left = random.randint(0, width - keep_w)

        window = tensor[:, :, top:top + keep_h, left:left + keep_w]

        # F.pad takes (left, right, top, bottom) for the last two dims.
        padding = (left, width - (left + keep_w), top, height - (top + keep_h))
        return F.pad(window, padding, value=0)

    def rand_flip(self, tensor):
        """Flip the batch along the axes enabled in ``config``.

        NOTE(review): despite the name there is no random draw here; the flip
        is applied on every call when the corresponding config flag is set.
        """
        if self.config.get('flip_horizontal', True):
            tensor = tensor.flip(3)
        if self.config.get('flip_vertical', False):
            tensor = tensor.flip(2)
        return tensor

    def rand_brightness(self, x):
        """Add one random offset per sample to every pixel.

        NOTE(review): the offset is ``rand - 1.0``, i.e. drawn from [-1, 0),
        so this only ever darkens. A symmetric jitter would use ``- 0.5``;
        confirm which is intended.
        """
        offset = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 1.0
        return x + offset

    def rand_saturation(self, x):
        """Scale each pixel's deviation from its channel mean by a factor in [0, 2)."""
        channel_mean = x.mean(dim=1, keepdim=True)
        factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
        return (x - channel_mean) * factor + channel_mean

    def rand_contrast(self, x):
        """Scale each sample's deviation from its global mean by a factor in [0.5, 1.5)."""
        sample_mean = x.mean(dim=[1, 2, 3], keepdim=True)
        factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
        return (x - sample_mean) * factor + sample_mean

    def rand_translation(self, x):
        """Shift each sample by a random integer offset, filling with zeros.

        The per-sample shift along each spatial axis is drawn uniformly from
        ``[-s, s]`` with ``s = int(size * translation_ratio + 0.5)``.
        """
        ratio = self.translation_ratio
        batch, _, height, width = x.shape
        shift_h = int(height * ratio + 0.5)
        shift_w = int(width * ratio + 0.5)
        translation_h = torch.randint(-shift_h, shift_h + 1, size=[batch, 1, 1], device=x.device)
        translation_w = torch.randint(-shift_w, shift_w + 1, size=[batch, 1, 1], device=x.device)

        # Broadcastable index tensors (B,1,1), (B,H,1), (B,1,W); advanced
        # indexing broadcasts them to (B,H,W).
        batch_idx = torch.arange(batch, dtype=torch.long, device=x.device).view(-1, 1, 1)
        rows = torch.arange(height, dtype=torch.long, device=x.device).view(1, -1, 1)
        cols = torch.arange(width, dtype=torch.long, device=x.device).view(1, 1, -1)

        # +1 accounts for the one-pixel zero border added below; clamping to
        # the border makes every out-of-range source pixel read as zero.
        rows = torch.clamp(rows + translation_h + 1, 0, height + 1)
        cols = torch.clamp(cols + translation_w + 1, 0, width + 1)

        x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
        shifted = x_pad.permute(0, 2, 3, 1).contiguous()[batch_idx, rows, cols]
        return shifted.permute(0, 3, 1, 2).contiguous()

    def rand_cutout(self, x, ratio_flag=False):
        """Zero a random rectangular patch in each sample.

        Args:
            x: ``(B, C, H, W)`` batch.
            ratio_flag: When True the patch side is half the image side;
                otherwise ``cutout_ratio`` is used.

        Returns:
            ``x`` multiplied by a per-sample 0/1 mask shared across channels.
        """
        ratio = 0.5 if ratio_flag else self.cutout_ratio
        batch, _, height, width = x.shape
        cut_h = int(height * ratio + 0.5)
        cut_w = int(width * ratio + 0.5)

        offset_h = torch.randint(0, height + (1 - cut_h % 2), size=[batch, 1, 1], device=x.device)
        offset_w = torch.randint(0, width + (1 - cut_w % 2), size=[batch, 1, 1], device=x.device)

        # Broadcastable index tensors covering the patch of every sample.
        batch_idx = torch.arange(batch, dtype=torch.long, device=x.device).view(-1, 1, 1)
        rows = torch.arange(cut_h, dtype=torch.long, device=x.device).view(1, -1, 1)
        cols = torch.arange(cut_w, dtype=torch.long, device=x.device).view(1, 1, -1)

        # Centre the patch on the offset and clip it to the image bounds.
        rows = torch.clamp(rows + offset_h - cut_h // 2, min=0, max=height - 1)
        cols = torch.clamp(cols + offset_w - cut_w // 2, min=0, max=width - 1)

        mask = torch.ones(batch, height, width, dtype=x.dtype, device=x.device)
        mask[batch_idx, rows, cols] = 0
        return x * mask.unsqueeze(1)